from fleming.hta.myNumPy import linspace
from fleming.hta.keypress_functions import getkey
from fleming.hta.call_function_sequence import header_output_Gen, header_Gen
import inspect
import socket
import csv
from datetime import datetime
from fleming.hta.static_functions import *
import threading
from fleming.hta.keypress_functions import getkey
import time


### Store one axis for scanning:
# AxisScan
# properties:
#     movor: a function (plain or coroutine) for moving the axis
#     steps: the positions that the axis travels through
#     currentIndex: index of the current position in steps (-1 before the first move)
    
# methods:
#     setMovor()
#     setRange()
#     move()
#     move2next()
#     nsteps(): number of steps
#     currentIndex():
#     currentPosition():
#     display(): not implemented yet.
#     
#     by Haisen Ta @ PE, 20221001
###     
class AxisScan:
    """A single scan axis: a mover function and the positions it steps through.

    Construction:
        AxisScan(movor, [left, right], n)  steps = linspace(left, right, n)
        AxisScan(movor, steps)             steps used as given
        AxisScan()                         configure later with setMovor()/setRange()

    ``currentIndex`` is the index into ``steps`` of the last position reached
    through move2next(); it starts at -1 so the first move2next() goes to 0.

    NOTE: the instance attribute ``currentIndex`` shadows the method of the
    same name, so ``obj.currentIndex`` is the int; the method is only
    reachable as ``AxisScan.currentIndex(obj)``.
    """

    def __init__(self, *args):
        if len(args) == 3:  # movor, [left, right], nsteps
            self.movor = args[0]
            self.steps = linspace(args[1][0], args[1][1], args[2])
        elif len(args) == 2:  # movor, explicit steps
            self.movor = args[0]
            self.steps = args[1]
        elif args:
            print("invalid arguments")
        # NOTE(review): movor should be callable and steps indexable;
        # neither is validated here.
        self.currentIndex = -1

    def setMovor(self, funcs):
        """Replace the mover function and reset currentIndex to -1."""
        self.movor = funcs
        self.currentIndex = -1

    def setRange(self, *args):
        """Replace the steps and reset currentIndex to -1.

        setRange([left, right], n) builds the steps with linspace;
        setRange(steps) uses the given steps as-is.
        """
        if len(args) == 2:  # [left, right], nsteps
            self.steps = linspace(args[0][0], args[0][1], args[1])
        elif len(args) == 1:  # explicit steps
            self.steps = args[0]
        elif args:
            print("invalid arguments")
        self.currentIndex = -1

    async def move(self, *args):
        """Move the axis to ``steps[index]`` without changing currentIndex.

        ``index`` defaults to currentIndex. A coroutine-function movor is
        awaited; a plain function is simply called.
        """
        index = args[0] if args else self.currentIndex
        if inspect.iscoroutinefunction(self.movor):
            await self.movor(self.steps[index])
        else:
            self.movor(self.steps[index])

    async def move2next(self):
        """Move to the next step, wrapping to 0 after the last one.

        Updates currentIndex only after the move succeeded and returns it.
        """
        nxt = (self.currentIndex + 1) % self.nsteps()
        await self.move(nxt)
        self.currentIndex = nxt
        return nxt

    def nsteps(self):
        """Return the number of steps."""
        return len(self.steps)

    def currentPosition(self):
        """Return ``steps[currentIndex]`` (the last step while currentIndex is -1)."""
        return self.steps[self.currentIndex]

    def currentIndex(self):
        # Shadowed by the instance attribute set in __init__ (see class docstring).
        return self.currentIndex

    def display(self, delimiter=";"):
        """Return the movor name followed by every step, joined by ``delimiter``."""
        txt = self.movor.__name__ + delimiter
        txt += delimiter.join(f"{x}" for x in self.steps)
        # The text used to be built and then discarded (implicit None return).
        return txt


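### Store a multi-dimensional scan:
# Scan
# properties:
#     axis: list of dimensions; each dimension is a list of AxisScan objects that step together
#     steps(): returns the steps of all axes, grouped by dimension
#     prefuncs: per dimension, None or a list of [func, *args] called before that dimension's axes move
#     funcs: functions passed to header_output_Gen at every scan point (set with functions())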
    
# methods:
#     currentIndex: list of per-dimension indices (an attribute, not a method)
#     currentPos():
#     display(): not implemented yet.
#     run():
#     next():
#     ind2subind()
#     subind2ind()
#     move2current()
# 
#     by Haisen Ta @ PE, 20221001
###     
class Scan:
    """Multi-dimensional scan over groups of AxisScan objects.

    ``axis`` is a list of dimensions; each dimension is a list of AxisScan-like
    objects that step together. Dimension 0 varies fastest (see ind2subind).
    ``prefuncs`` / ``postfuncs`` hold, per dimension, either None or a list of
    ``[func, *args]`` entries that are called when that dimension's index
    changes: prefuncs before its axes move, postfuncs after the point has
    been measured. ``funcs`` (set with functions()) is what header_output_Gen
    evaluates at every point.
    """

    def __init__(self, *args):
        if len(args) == 1:
            self.axis = args[0]
            self.prefuncs = [None] * len(self.axis)
            self.postfuncs = [None] * len(self.axis)
        elif len(args) == 2:
            self.axis, self.prefuncs = args
            self.postfuncs = [None] * len(self.axis)
        elif len(args) == 3:
            self.axis, self.prefuncs, self.postfuncs = args
        else:
            # Previously this printed and then failed with AttributeError below.
            raise TypeError("invalid arguments")

        # -1 in every dimension: subind2ind wraps it to the last flat index,
        # so the first next() lands on [0, 0, ...].
        self.currentIndex = [-1] * len(self.axis)
        self.dims = self.viable()

    def viable(self):
        """Return the number of steps of each dimension.

        A dimension's size is the largest nsteps() among its axes. Every axis
        in a dimension should have 0, 1 or that maximum number of steps (a
        1-step axis just stays in place); any other count is reported.
        """
        dims = []
        for ax in self.axis:
            counts = [a.nsteps() for a in ax]
            size = max(counts, default=0)
            # Check every axis against the final maximum, so the result does
            # not depend on the order in which the axes are listed.
            if any(n not in (0, 1, size) for n in counts):
                print("error! dimension can be either 0, 1 or max!")
            dims.append(size)
        return dims

    def steps(self, *args):
        """Return the ``steps`` of every axis, grouped by dimension.

        steps()                  -> one list per dimension, holding its axes' steps
        steps(d) / steps([d...]) -> the same, restricted to the given dimension
                                    index (or iterable of indices)
        """
        if args:
            dims = args[0]
            if isinstance(dims, int):
                dims = [dims]
        else:
            dims = range(len(self.axis))
        return [[a.steps for a in self.axis[d]] for d in dims]

    def currentPos(self):
        """Return the current position of every axis, flattened over dimensions.

        Each axis is indexed modulo its own length, matching the wrap-around
        of AxisScan.move2next(), so 1-step axes in a longer dimension work.
        """
        return [
            a.steps[i % len(a.steps)]
            for ax, i in zip(self.axis, self.currentIndex)
            for a in ax
        ]

    def functions(self, funcs):
        """Set the functions evaluated (via header_output_Gen) at each point."""
        self.funcs = funcs

    def lenprod(self):
        """Return the total number of scan points (product of dimension sizes).

        The sizes are recomputed (and stored in self.dims) so the count agrees
        with ind2subind/subind2ind, which also call viable().
        """
        self.dims = self.viable()
        total = 1
        for d in self.dims:
            total *= d
        return total

    async def header(self, *args):
        """Return one header line built by header_output_Gen.

        header(funcs) -> the generated info text plus a newline.
        header()      -> "time\\t" + the column header for self.funcs, plus newline.
        """
        if args:
            _, _, info = await header_output_Gen(args[0], True)
            return info + "\n"
        _, h, _ = await header_output_Gen(self.funcs, True)
        return "time\t" + h + "\n"

    async def fullheader(self, *args):
        """Return header_Gen(funcs, flagRunCode=True) plus a newline.

        ``funcs`` defaults to self.funcs.
        """
        funcs = args[0] if args else self.funcs
        h0 = await header_Gen(funcs, flagRunCode=True)
        return h0 + "\n"

    async def move2current(self):
        """Move every axis to its own current position (indices unchanged)."""
        for ax in self.axis:
            for a in ax:
                await a.move()

    @staticmethod
    async def _call_all(calls):
        """Call each ``[func, *args]`` entry in order, awaiting coroutine functions."""
        for func, *fargs in calls:
            if inspect.iscoroutinefunction(func):
                await func(*fargs)
            else:
                func(*fargs)

    async def next(self):
        """Advance to the next scan point, measure it, and return the output.

        For each dimension whose index changes, its prefuncs are called and
        then all of its axes move2next(). The point is measured with
        header_output_Gen(self.funcs, False), and finally the postfuncs of
        the dimensions whose index changed are called.
        """
        funcs = self.funcs
        previous = list(self.currentIndex)
        upcoming = self.ind2subind(self.subind2ind(previous) + 1)

        for old, new, pre, ax in zip(previous, upcoming, self.prefuncs, self.axis):
            if old == new:
                continue
            if pre is not None:
                await self._call_all(pre)
            for a in ax:
                await a.move2next()

        self.currentIndex = upcoming
        output = await header_output_Gen(funcs, False)

        # Compare with the index from *before* this step. Comparing with
        # self.currentIndex (already updated above) meant postfuncs never ran.
        for old, new, post in zip(previous, upcoming, self.postfuncs):
            if old != new and post is not None:
                await self._call_all(post)

        return output

    async def run(self, fname="test.csv", cmt="comment", prefuncs=None):
        """Run the whole scan, appending the results to ``fname``.

        Writes a preamble (comment + hostname, then the full header of
        ``prefuncs``), a column header line, and one tab-separated row
        (timestamp + output of next()) per scan point. A helper thread waits
        for a key press; once the shared 'break' flag is set, the scan stops
        after the current point.
        """
        if prefuncs is None:
            prefuncs = []

        with open(fname, 'a') as f:
            f.write(f"{cmt}:{socket.gethostname()}\n")
            f.write("precalled func:")
            h0 = await self.fullheader(prefuncs)
            f.write(h0 + "\n")

        # actual core functions with header
        with open(fname, 'a') as f:
            writer = csv.writer(f, delimiter='\t', dialect='excel-tab')
            f.write(await self.header())

            # NOTE(review): this thread is neither joined nor a daemon; if it
            # is still waiting for a key when the scan ends, interpreter exit
            # may block until a key is pressed — confirm this is intended.
            key_thread = AxisThreadHelper(1, "Thread-1", 5)
            key_thread.start()

            total = self.lenprod()
            start_time = time.time()
            for i in range(total):
                output = await self.next()
                tstr = datetime.now().strftime("%y.%m.%d.%H.%M.%S.%f")
                writer.writerow([tstr, *output])

                ratio = (i + 1) / total
                remaining = (time.time() - start_time) * (1 - ratio) / ratio
                print(f"{i}: {int(ratio*100)}%\t\t estimated remainig time: {remaining:3.1f} s")
                print(output)
                print()

                if static_boolen('break'):
                    break

    def ind2subind(self, ind):
        """Convert a flat index into per-dimension indices (dimension 0 fastest).

        The flat index wraps modulo the total number of points.
        """
        sub = []
        rem = ind
        for d in self.viable():
            sub.append(rem % d)
            rem //= d
        return sub

    def subind2ind(self, subind):
        """Convert per-dimension indices into a flat index (dimension 0 fastest).

        Each index is taken modulo its dimension size, so -1 means "last".
        """
        ind = 0
        stride = 1
        for s, d in zip(subind, self.viable()):
            ind += (s % d) * stride
            stride *= d
        return ind

    def ranges(self):
        """Placeholder; not implemented."""

    def display(self):
        """Print and return the display() text of every axis, one per line."""
        txt = "\n".join(a.display() for ax in self.axis for a in ax)
        print(txt)
        return txt


### Helper thread that watches the keyboard during a scan:
# AxisThreadHelper
# properties:
#     threadID: 
#     name: thread name
#     counter: stored but not used in this class
# 
# methods:
#     refer to threading.Thread
# 
#     by Haisen Ta @ PE, 20221001
###     
class AxisThreadHelper(threading.Thread):
    """Thread that runs setflag() once, so a key press can interrupt a scan.

    Creating an instance clears the shared 'break' flag, so a stale interrupt
    from an earlier run does not stop the next one.
    """

    def __init__(self, threadID, name, counter):
        super().__init__()
        self.threadID, self.name, self.counter = threadID, name, counter
        # Reset the interrupt flag before the thread is started.
        static_boolen('break', False)

    def run(self):
        print(f"Starting {self.name}")
        setflag()
        print(f"Exiting {self.name}")

def setflag():
    """Wait for one key from getkey(); on 'a', set the shared 'break' flag."""
    if getkey() != 'a':
        return
    static_boolen('break', True)
    print('measuremnt interrupted by user: "a"')



# fleming.hta.AxisScan.reload
def reload():
    """Placeholder hook: currently does nothing and returns None."""


